#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import configparser
import csv
import hashlib
import logging
import math
import os
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from threading import Lock

import pandas as pd
import pytz
from bson import ObjectId
from pymongo import MongoClient, InsertOne
from pymongo.errors import BulkWriteError
from snowflake import SnowflakeGenerator

# Logging configuration
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

# Default configuration file path (overridable via --config)
config_file_path = 'config.ini'

# Maximum attempts when drawing a snowflake ID
ID_MAX_RETRIES = 3

# Records that failed to insert, shared by all worker threads
all_failed_records = []
failed_records_lock = threading.Lock()  # guards all_failed_records across threads

# Shared progress counter and the lock that protects it
global_progress = 0
progress_lock = Lock()

# Column dtypes applied when reading the CSV
dtype = {"isOilIncluded": str, "kilometers": float, "amount": float, "enOilId": str}

# Columns to drop from the input; populated from the config file in main()
columns_to_ignore = []

# Matches date format dd/MM/yyyy HH:mm:ss
# datetime_format = '%d/%m/%Y %H:%M:%S'
# Matches date format yyyy-MM-dd HH:mm:ss
datetime_format = '%Y-%m-%d %H:%M:%S'


# Business-level data transformation
def biz_data_convert(data):
    """Apply business transformations to a batch of maintenance records.

    Adds the fixed persistence class marker and an update timestamp,
    normalises the oil/amount/kilometre fields, and drops any columns
    listed in the module-level ``columns_to_ignore``.

    :param data: list of record dicts (one CSV row each).
    :return: list of transformed record dicts.
    """
    frame = pd.DataFrame(data)

    # Fixed class discriminator expected by the downstream consumer.
    frame["_class"] = 'com.wuyang.honda.maintain.entity.MaintainRecord'

    # Stamp every record with "now" in the Asia/Shanghai timezone.
    now_shanghai = datetime.now(pytz.utc).astimezone(pytz.timezone("Asia/Shanghai"))
    frame["updateTime"] = now_shanghai

    # Oil-included flag: "yes"/"no" (case-insensitive) -> 1/0;
    # unmapped values keep their original content.
    if "isOilIncluded" in frame.columns:
        mapped_flags = frame["isOilIncluded"].astype(str).str.lower().map({"yes": 1, "no": 0})
        frame["isOilIncluded"] = mapped_flags.fillna(frame["isOilIncluded"])

    # Non-numeric or missing amounts become 0.
    if "amount" in frame.columns:
        frame["amount"] = pd.to_numeric(frame["amount"], errors="coerce").fillna(0)

    # Non-numeric or missing kilometre readings become 0, stored as int.
    if "kilometers" in frame.columns:
        kilometres = pd.to_numeric(frame["kilometers"], errors="coerce")
        frame["kilometers"] = kilometres.fillna(0).astype(int)

    # Remove columns the import is configured to ignore.
    frame.drop(columns=columns_to_ignore, inplace=True, errors="ignore")

    return frame.to_dict(orient="records")


# Load the INI configuration file
def read_config(file_path):
    """Load an INI configuration file and return the parser.

    RawConfigParser is used so '%' characters in values (common in
    connection URIs) are not treated as interpolation syntax.
    """
    parser = configparser.RawConfigParser()
    parser.read(file_path)
    return parser


# Date field handling
def convert_dates(dataframe, date_columns, fmt=None):
    """Parse the given columns of *dataframe* as datetimes in place.

    Values that do not match the format are coerced to NaT and then
    replaced with the Unix epoch (1970-01-01) so downstream code always
    sees a valid timestamp.

    :param dataframe: pandas DataFrame to modify.
    :param date_columns: iterable of column names to convert; names not
        present in the frame are silently skipped.
    :param fmt: strptime-style format string; defaults to the
        module-level ``datetime_format``.
    :return: the same DataFrame (mutated in place).
    """
    if fmt is None:
        fmt = datetime_format
    for column in date_columns:
        if column in dataframe.columns:
            # Fix: the original also passed dayfirst=True, which pandas
            # ignores (with a warning) whenever an explicit format is
            # supplied — dropped here to silence the contradiction.
            dataframe[column] = pd.to_datetime(
                dataframe[column], format=fmt, errors="coerce"
            ).fillna(pd.Timestamp('1970-01-01'))
    return dataframe


def snowflake_id_to_objectid(snowflake_id):
    """Derive a deterministic MongoDB ObjectId from a snowflake ID.

    The decimal string form of the ID is hashed with MD5 (used here only
    as a deterministic mapping, not for security) and the first 24 hex
    digits — a valid ObjectId representation — become the ObjectId.
    """
    digest = hashlib.md5(str(snowflake_id).encode()).hexdigest()
    return ObjectId(digest[:24])


# Strip NaN, None and empty-string values
def clean_nan_values(data):
    """Remove empty-string, None and NaN values from each record in place.

    :param data: list of dicts; each dict is mutated directly.
    :return: the same list, for call chaining.
    """
    for record in data:
        empty_keys = [
            key for key, value in record.items()
            if value in ["", None] or pd.isna(value)
        ]
        for key in empty_keys:
            record.pop(key)
    return data


# Populate derived and default field values
def add_default_prop_values(data):
    """Fill in derived and default fields on each record in place.

    - ``outEvalDate``: 7 days after ``maintainDate``, truncated to
      midnight; added only when maintainDate exists and is a datetime.
    - ``serviceEvalId``: defaults to -1 when the record has no value
      (null values have already been stripped by clean_nan_values
      upstream in this pipeline).

    :param data: list of dicts; each dict is mutated directly.
    :return: the same list, for call chaining.
    """
    for record in data:
        maintain_date = record.get("maintainDate")
        if isinstance(maintain_date, datetime):
            out_eval = maintain_date + pd.Timedelta(days=7)
            # Keep only the date part (00:00:00).
            record["outEvalDate"] = out_eval.replace(hour=0, minute=0, second=0, microsecond=0)

        # Fix: use setdefault so a record that already carries a
        # serviceEvalId keeps it; the original unconditionally
        # overwrote the field with -1, contradicting its own comment
        # ("default for null values").
        record.setdefault("serviceEvalId", -1)

    return data


# Draw a unique ID, with retries on failure
def generate_unique_id(sf_gen, record, max_retries=None):
    """Draw the next ID from a snowflake generator, retrying on failure.

    :param sf_gen: iterator/generator yielding snowflake IDs.
    :param record: the record being processed, included in warning logs.
    :param max_retries: number of attempts; defaults to the module-level
        ``ID_MAX_RETRIES`` (parameterized for testability — backward
        compatible with existing two-argument callers).
    :return: the generated ID, or None if every attempt failed.
    """
    attempts = ID_MAX_RETRIES if max_retries is None else max_retries
    for attempt in range(attempts):
        try:
            return next(sf_gen)
        except Exception as e:
            logging.warning(f"生成唯一ID失败，重试次数: {attempt + 1}，记录: {record}，错误: {e}")
    return None


# Bulk-insert records into MongoDB
def insert_to_mongo(data, db, collection_name):
    """Bulk-insert *data* (a list of dicts) into the named collection.

    Uses an ordered bulk write; any BulkWriteError raised by the driver
    is left for the caller to handle.
    """
    operations = [InsertOne(document) for document in data]
    db[collection_name].bulk_write(operations)


# Update the shared import progress counter and log the batch result
def update_progress(processed_rows, total_rows, batch_elapsed_time, thread_name, file_path):
    """Atomically add *processed_rows* to the global counter and log progress.

    Called from multiple worker threads; progress_lock serialises both
    the counter update and the log line so reported totals stay
    consistent.

    :param processed_rows: rows inserted in this batch.
    :param total_rows: total row count across all part files.
    :param batch_elapsed_time: seconds spent on this batch.
    :param thread_name: name of the reporting worker thread.
    :param file_path: part file the thread is processing.
    """
    global global_progress
    with progress_lock:
        global_progress += processed_rows
        progress_percent = (global_progress / total_rows) * 100
        logging.info(
            f"线程 {thread_name} 处理文件 {file_path}：插入 {processed_rows} 条数据，耗时 {batch_elapsed_time:.2f} 秒，总进度：{global_progress}/{total_rows} ({progress_percent:.2f}%)"
        )


# Process one part file chunk-by-chunk, recording failed batches
def process_file_with_error_handling(file_path, batch_size, db, collection_name, date_columns, total_rows, instance):
    """Stream a CSV part file into MongoDB in batches of *batch_size* rows.

    Each chunk goes through date conversion, business transformation,
    NaN cleanup and default-value population, then receives a
    deterministic ObjectId derived from a per-thread snowflake generator
    before bulk insertion. Batches that fail to insert are appended to
    the global all_failed_records list (under failed_records_lock) for
    later export.

    :param file_path: path of the CSV part file to import.
    :param batch_size: rows per chunk read from the CSV.
    :param db: pymongo database handle.
    :param collection_name: target collection name.
    :param date_columns: column names to parse as datetimes.
    :param total_rows: overall row count, for progress reporting.
    :param instance: unique snowflake instance number for this thread.
    """
    sf_gen = SnowflakeGenerator(instance=instance)
    total_inserted = 0
    start_time = time.time()

    # Fix: pass the module-level dtype mapping, matching
    # split_csv_file_by_threads, so re-reading the part files preserves
    # the intended column types (e.g. enOilId stays a string).
    for chunk in pd.read_csv(file_path, chunksize=batch_size, dtype=dtype):
        batch_start_time = time.time()
        chunk = convert_dates(chunk, date_columns)
        data = chunk.to_dict(orient="records")
        data = biz_data_convert(data)
        data = clean_nan_values(data)
        data = add_default_prop_values(data)

        # Attach IDs; records whose ID generation failed are dropped.
        ids = [generate_unique_id(sf_gen, record) for record in data]
        data = [{**record, "_id": snowflake_id_to_objectid(_id)} for record, _id in zip(data, ids) if _id]

        if not data:
            continue

        try:
            insert_to_mongo(data, db, collection_name)
            # Fix: dedicated name instead of shadowing the batch_size
            # parameter as the original did.
            inserted_count = len(data)
            total_inserted += inserted_count
            batch_elapsed_time = time.time() - batch_start_time
            update_progress(inserted_count, total_rows, batch_elapsed_time, threading.current_thread().name, file_path)
        except BulkWriteError as bwe:
            logging.exception(f"MongoDB 批量插入错误: {bwe.details}")
            with failed_records_lock:
                all_failed_records.extend(data)
        except Exception as e:
            logging.exception(f"数据插入失败: {e}")
            with failed_records_lock:
                all_failed_records.extend(data)

    elapsed_time = time.time() - start_time
    logging.info(f"线程 {instance} 文件 {file_path} 总耗时 {elapsed_time:.2f} 秒，共插入 {total_inserted} 条数据。")


# After all threads finish, export every failed record to one CSV file
def save_failed_records_to_csv():
    """Write all_failed_records to all_failed_records.csv, if any exist.

    Fix: the original used only the first record's keys as the header;
    because clean_nan_values strips a different set of keys per record,
    later records could carry extra keys and make DictWriter raise
    ValueError. The header is now the union of all keys (first-seen
    order) and missing values are written as empty cells.
    """
    if not all_failed_records:
        return

    # Union of keys across all records, preserving first-seen order.
    fieldname_order = {}
    for record in all_failed_records:
        for key in record:
            fieldname_order.setdefault(key, None)

    error_file_path = "all_failed_records.csv"
    # utf-8-sig so the file opens without mojibake in Excel.
    with open(error_file_path, mode="w", encoding="utf-8-sig", newline="") as error_file:
        writer = csv.DictWriter(error_file, fieldnames=list(fieldname_order), restval="")
        writer.writeheader()
        writer.writerows(all_failed_records)
    print(f"[错误] 所有失败记录已保存到 {error_file_path}")


# Split the source CSV into one part file per worker thread
def split_csv_file_by_threads(csv_file, threads, output_dir, total_rows, col_dtypes=None):
    """Split *csv_file* into up to *threads* part files under *output_dir*.

    If the output directory already exists and is non-empty, the split is
    skipped and the existing part files are returned. Fix: that listing
    is now sorted — the original relied on os.listdir order, which is
    arbitrary, making the part-to-thread-instance mapping
    non-deterministic across runs.

    :param csv_file: path of the CSV to split (header is repeated in
        every part file).
    :param threads: desired number of parts.
    :param output_dir: directory for the part_N.csv files.
    :param total_rows: data row count of the CSV, excluding the header.
    :param col_dtypes: optional dtype mapping for pd.read_csv; defaults
        to the module-level ``dtype`` (parameterized for testability —
        backward compatible with existing four-argument callers).
    :return: list of part file paths, or [] on error.
    """
    if os.path.exists(output_dir) and os.listdir(output_dir):
        logging.info(f"拆分目录 {output_dir} 已存在，跳过拆分。")
        return sorted(
            os.path.join(output_dir, f) for f in os.listdir(output_dir) if f.endswith(".csv")
        )

    if col_dtypes is None:
        col_dtypes = dtype

    os.makedirs(output_dir, exist_ok=True)
    # Fix: guard against chunksize=0 when the file has no data rows.
    rows_per_file = max(1, math.ceil(total_rows / threads))

    part_files = []
    try:
        for i, chunk in enumerate(pd.read_csv(csv_file, chunksize=rows_per_file, dtype=col_dtypes)):
            part_file = os.path.join(output_dir, f"part_{i + 1}.csv")
            chunk.to_csv(part_file, index=False, header=True)
            part_files.append(part_file)
            logging.info(f"拆分文件 {part_file} 包含 {len(chunk)} 行。")
    except Exception as e:
        logging.error(f"拆分 CSV 文件错误: {e}")
        return []

    logging.info(f"拆分完成，共生成 {len(part_files)} 个文件。")
    return part_files


def count_rows_exclude_empty(file_path):
    """Count data rows in a CSV file, skipping the header and blank lines."""
    with open(file_path, 'r', encoding='utf-8') as csv_file:
        next(csv_file, None)  # discard the header line (if any)
        return sum(1 for line in csv_file if line.strip())


# Entry point: parse arguments, split the CSV and import it in parallel
def main():
    """Parse CLI arguments, split the input CSV into per-thread part
    files, and import them into MongoDB concurrently.

    Reads connection and column settings from an INI config file, fans
    the part files out to a thread pool (one unique snowflake instance
    number per worker), then exports any failed batches collected by the
    workers.
    """
    parser = argparse.ArgumentParser(description="将 CSV 数据分批导入 MongoDB")
    parser.add_argument("csv_file", type=str, help="CSV 文件路径")
    parser.add_argument("--batch_size", type=int, default=5000, help="每批次处理的数据量（默认：5000）")
    parser.add_argument("--config", type=str, default=config_file_path, help="配置文件路径")
    parser.add_argument("--threads", type=int, default=10, help="线程数（默认:10）")
    parser.add_argument("--output_dir", type=str, default="./csv_parts", help="拆分后的文件输出目录")
    args = parser.parse_args()

    config = read_config(args.config)
    mongo_uri = config["MongoDB"]["uri"]
    database_name = config["MongoDB"]["database"]
    collection_name = config["MongoDB"]["collection"]
    date_columns = [col.strip() for col in config["CSV"]["date_columns"].split(",") if col.strip()]
    # Read "columns_to_ignore" from the config: strip whitespace, split on
    # commas and drop empty entries; extends the module-level list that
    # biz_data_convert consults.
    columns_to_ignore.extend([col.strip() for col in config["CSV"]["columns_to_ignore"].split(",") if col.strip()])

    client = MongoClient(mongo_uri)
    db = client[database_name]

    # Total data rows in the CSV (header and blank lines excluded).
    total_rows = count_rows_exclude_empty(args.csv_file)
    part_files = split_csv_file_by_threads(args.csv_file, args.threads, args.output_dir, total_rows)

    start_time = time.time()
    with ThreadPoolExecutor(max_workers=args.threads) as executor:
        futures = [
            executor.submit(
                process_file_with_error_handling,
                part_file,
                args.batch_size,
                db,
                collection_name,
                date_columns,
                total_rows,
                instance=i,  # unique snowflake instance number per worker
            )
            for i, part_file in enumerate(part_files)
        ]

        for future in futures:
            future.result()  # wait for every worker to finish

    elapsed_time = time.time() - start_time
    print(f"[完成] 全部数据导入完成，耗时 {elapsed_time:.2f} 秒。")

    # Export failed records to a CSV file for inspection/retry.
    save_failed_records_to_csv()


if __name__ == "__main__":
    main()  # run only when executed as a script, not on import
